#
# Counterfactuals for baby model with NIPT
#

#### Set-Up ####

rm(list = ls())
Sys.info()[7]

# Project directories. WORKING is a placeholder for the local project root;
# the data / temp / results folders are built from it.
WORKING <- "MYPATH"
DATA <- paste0(WORKING, "/Data")
TEMP <- paste0(WORKING, "/Temp")
RESULTS <- paste0(WORKING, "/Results")

# Run from the code folder (the relative source() paths further down depend
# on it) and divert console output to a log file.
setwd("MYPATH/estimate_w_xs")
sink("counterfactuals_xs.txt")

# Make sure the output sub-folders exist in Temp and Results.
for (folder in c(paste0(TEMP, "/estimate_w_xs"), paste0(RESULTS, "/estimate_w_xs"))) {
  if (!file.exists(folder)) {
    dir.create(folder)
  } else {
    cat("The folder already exists")
  }
}

Sys.info()[7]
Sys.time()

# Run switches:
#   condition_data = 1 -> keep only (a_i, c_i) draws consistent with the
#                         observed testing decisions (Step 1 below)
#   expand = 1         -> stack K copies of the sample (Step 0a below)
condition_data <- 0
expand <- 1


# Packages
.libPaths(c("MYPATH/R-library"))
library(MASS)
library(tmvtnorm)
library(haven)
library(dplyr)
library(tidyr)
library(nloptr)
library(corpcor)
library(stargazer)
library(pracma)
library(foreign)
library(writexl)
library(broom)
library(png)
library(tidyverse)
library(compositions) # for rlnorm.rplus
library(data.table) # for rbindlist


#### Load Data + Assumptions ####

# Estimation sample plus the 0/1 covariates used by the model. Missing
# education / income quartile are coded 0 in the indicators; a missing
# oop_i becomes Inf.
# NOTE(review): age_25_34 covers [25, 35) while age_35_plus uses age > 35,
# so a mother aged exactly 35 is in neither group. Left unchanged because
# the parameter estimates loaded below were presumably built with the same
# definition -- confirm against the estimation code before changing.
analysis <- read_dta(paste0(DATA, "/model_sample.dta")) %>%
  mutate(
    oop_i = ifelse(is.na(oop_i), Inf, oop_i),
    age_25_34 = ifelse(age >= 25 & age < 35, 1, 0),
    age_35_plus = ifelse(age > 35, 1, 0),
    some_college = ifelse(educ == 1 | is.na(educ), 0, 1),
    missing_college = ifelse(is.na(educ), 1, 0),
    any_prev_birth_issue = ifelse(prev_concern == 1 | prev_any_q_icd == 1 | prev_mis_still == 1, 1, 0),
    inc_quartile_4 = ifelse(inc_quartile != 4 | is.na(inc_quartile), 0, 1),
    full_college = ifelse(educ != 3 | is.na(educ), 0, 1)
  )

# Model constants
pa <- 0.005   # probability that an invasive test causes a miscarriage (Step 2)
pfp <- 0.01   # NIPT false-positive rate (Step 0c)
pfn <- 0.01   # NIPT false-negative rate (Step 0c)

# NOTE(review): J is not used in the visible part of this script; it may be
# read as a global by the sourced model functions -- confirm.
J <- 1

print("Begin time")
start_time <- Sys.time()
print(start_time)

#### Load Necessary Model Functions ####

source("r_functions/production_model_xs_new.R")
source("r_functions/counterfactual_functions.R")


# Notes on the per-observation decision function, which is now loaded by the
# source() calls above (presumably model_decisions_counterfactuals):
  # NEW: does not draw a + c (takes a_i and c_i as given)
  # NEW: takes rec as given
  # NEW: adds a possible out-of-pocket cost for invasive testing = oop2_i
  # NEW: allows the mother's prior to differ from the truth (truth = p_i = KUB score)

#### Simulate data based on parameter estimates + model ####

# Demean each model covariate within the estimation sample.
# fc_n duplicates full_college_norm (NOTE(review): not used in the visible
# part of this script).
analysis <- analysis %>%
  mutate(
    full_college_norm = full_college - mean(full_college),
    fc_n = full_college - mean(full_college),
    inc_quartile_4_norm = inc_quartile_4 - mean(inc_quartile_4),
    married_norm = married - mean(married),
    mom_foreign_norm = mom_foreign - mean(mom_foreign),
    any_prev_birth_issue_norm = any_prev_birth_issue - mean(any_prev_birth_issue),
    prev_kids_norm = dv_prev_kids - mean(dv_prev_kids),
    age_35_plus_norm = age_35_plus - mean(age_35_plus)
  )

# Covariates shifting the a_i and c_i distributions (no intercept).
x_vars <- c("full_college_norm", "inc_quartile_4_norm", "married_norm", "mom_foreign_norm", "any_prev_birth_issue_norm", "age_35_plus_norm", "prev_kids_norm")
f_a <- ~ 0 + full_college_norm + inc_quartile_4_norm + married_norm + mom_foreign_norm + any_prev_birth_issue_norm + age_35_plus_norm + prev_kids_norm
f_c <- ~ 0 + full_college_norm + inc_quartile_4_norm + married_norm + mom_foreign_norm + any_prev_birth_issue_norm + age_35_plus_norm + prev_kids_norm

# Parameter estimates; rows 21 and 22 are dropped before use.
est_par <- as.matrix(read.csv(paste0(RESULTS, "/estimate_w_xs/results_no_xs.csv")))
est_par <- est_par[-c(21, 22), ]
print(est_par)

# Switch: 1 = run the counterfactuals on the full analysis sample
# (analysis_sample.dta merged with the screening covariates) instead of the
# estimation sample `analysis`.
use_full_sample <- 0
if (use_full_sample == 1) {
  full_sample <- read_dta(paste0(DATA, "/analysis_sample.dta")) 
  # Screening-file covariates. Variables that full_sample also carries get a
  # "2" suffix; the suffixed copies and the is_right_* comparison flags are
  # dropped again below.
  full_Xs <- read_dta(paste0(DATA, "/all_screens_universalnt_w_xs.dta")) %>%
    rename(inc_quartile2 = inc_quartile) %>%
    rename(age2 = age) %>%
    rename(educ2 = educ) %>%
    rename(mom_foreign2 = mom_foreign) %>%
    dplyr::select(pregnancy, prev_mis_still, age2, educ2, prev_concern, prev_any_q_icd, inc_quartile2, married, mom_foreign2)
  # Merge and build the same derived covariates as for `analysis` above.
  full_sample <- full_sample %>%
    inner_join(full_Xs, by = "pregnancy") %>%
    mutate(is_right_inc = (inc_quartile == inc_quartile2)) %>%
    mutate(is_right_age = (age == age2)) %>%
    mutate(is_right_ed = (educ == educ2)) %>%
    mutate(is_right_for = (mom_foreign == mom_foreign2)) %>%
    mutate(age_25_34 = ifelse(age >= 25 & age < 35, 1, 0)) %>%
    mutate(age_35_plus = ifelse(age>35,1,0)) %>%
    mutate(some_college = ifelse(educ == 1 | is.na(educ), 0, 1)) %>%
    mutate(missing_college = ifelse(is.na(educ), 1, 0)) %>%
    mutate(any_prev_birth_issue = ifelse(prev_concern == 1 | prev_any_q_icd == 1 | prev_mis_still == 1, 1,0)) %>%
    mutate(inc_quartile_4 = ifelse(inc_quartile != 4 | is.na(inc_quartile), 0, 1)) %>%
    mutate(full_college = ifelse(educ != 3 | is.na(educ), 0, 1))
  full_sample = subset(full_sample, select = -c(is_right_inc, is_right_age, is_right_ed, is_right_for, inc_quartile2, age2, educ2, mom_foreign2))
  # Take oop_i from the estimation sample; pregnancies not found there get
  # Inf in wave 2 and 567.5 otherwise.
  oop <- analysis %>%
    select(pregnancy, oop_i)
  full_sample <- full_sample %>%
    left_join(oop, by = "pregnancy") %>%
    mutate(oop_i = ifelse(is.na(oop_i)==TRUE, ifelse(wave == 2, Inf, 567.5), oop_i)) %>%
    rename(p_i = kub_score)
  
  # Constant placeholder columns.
  # NOTE(review): ave_kub_age is used as the mother's prior in
  # Counterfactuals 10-11; a constant 1 looks like a placeholder -- confirm
  # before using the full sample for those counterfactuals.
  full_sample <- full_sample %>%
    mutate(bin_number = 1) %>%
    mutate(policy_id = 1) %>%
    mutate(policy_regime = 1) %>%
    mutate(ave_kub_age = 1)
  
  # Demean the covariates. Every column listed in x_vars (and used in
  # f_a / f_c) must exist, because those are passed to draw_a_c_counter_xs()
  # in Step 1 -- this includes prev_kids_norm and age_35_plus_norm.
  # NOTE(review): assumes analysis_sample.dta provides dv_prev_kids, as
  # model_sample.dta does.
  full_sample <- full_sample %>%
    mutate(full_college_norm = (full_college - mean(full_college))) %>%
    mutate(inc_quartile_4_norm = (inc_quartile_4 - mean(inc_quartile_4))) %>%
    mutate(married_norm = (married - mean(married))) %>%
    mutate(mom_foreign_norm = (mom_foreign - mean(mom_foreign))) %>%
    mutate(any_prev_birth_issue_norm = (any_prev_birth_issue - mean(any_prev_birth_issue))) %>%
    mutate(prev_kids_norm = (dv_prev_kids - mean(dv_prev_kids))) %>%
    mutate(age_35_plus_norm = (age_35_plus - mean(age_35_plus)))

}


## Simulation sample: `actual` holds the observed data to which the
## simulated chromosomal-abnormality indicator (Step 0b) and NIPT results
## (Step 0c) are attached for the counterfactuals.
## It is the full analysis sample when use_full_sample == 1, otherwise the
## estimation sample.
actual <- if (use_full_sample == 1) full_sample else analysis

#Step 0a: stack K identical copies of the sample so every pregnancy gets K
# independent simulation draws in the steps below.
# NOTE: K is only defined when expand == 1.
if (expand == 1) {
  K <- 1000
  if (K > 1) {
    actual <- data.frame(rbindlist(rep(list(data.frame(actual)), K)))
  }
}

#Step 0b: simulate positive indicator for chromAb 
# sim_positive = 1 with probability p_i. The seed fixes this and every later
# draw, so the order of the runif() calls below must not change.
set.seed(1292837)
actual <- actual %>% 
  mutate(draw = runif(nrow(.), min=0, max=1)) %>%
  mutate(sim_positive = ifelse(draw <= p_i, 1, 0)) %>% 
  select(-draw)

#Step 0c: simulate result of NIPT test, if taken (using simulated chromAb indicator)
# A true positive tests positive with probability 1 - pfn; a true negative
# tests positive with probability pfp. The result is simulated for every row.
# (Original note: added here to address pregnancies being dropped in the
# conditional counterfactual.)
actual <- actual %>% 
  mutate(draw = runif(nrow(.), min=0, max=1)) %>%
  mutate(sim_nipt_result = ifelse(sim_positive == 1,
                                  ifelse(draw > pfn, 1, 0),
                                  ifelse(draw <= pfp, 1, 0))) %>%
  select(-draw)

###### CHECKS FOR SIM_NIPT_RESULT ######
# Empirical false-negative / false-positive rates of the simulated NIPT
# results, to compare with pfn / pfp. Tables are wrapped in print() so they
# reach the log whether the script is run with Rscript or source()
# (source() does not auto-print top-level values by default).
print("sim_nipt_results after changing procedure")
print("table of actual$sim_positive")
print(table(actual$sim_positive))
pos_ca <- actual %>%
  filter(sim_positive == 1)
neg_test_ca <- pos_ca %>%
  filter(sim_nipt_result == 0)
print("pos_ca$sim_nipt_result")
print(table(pos_ca$sim_nipt_result))
pos_ca_count <- nrow(pos_ca)
neg_test_pos_ca <- nrow(neg_test_ca)
false_neg_rate <- neg_test_pos_ca/pos_ca_count
print("false negative rate")
print(false_neg_rate)

neg_ca <- actual %>%
  filter(sim_positive == 0)
pos_test_no_ca <- neg_ca %>%
  filter(sim_nipt_result == 1)
print("neg_ca$sim_nipt_result")
print(table(neg_ca$sim_nipt_result))
neg_ca_count <- nrow(neg_ca)
pos_test_neg_ca <- nrow(pos_test_no_ca)
false_pos_rate <- pos_test_neg_ca/neg_ca_count
print("false positive rate")
print(false_pos_rate)

rm(pos_ca, neg_test_ca, neg_ca, pos_test_no_ca)
print(table(actual$sim_nipt_result, useNA = "ifany"))
#################

# Step 1: draw (a_i, c_i) for every row of the (expanded) data from the
# parameter estimates and compute the predicted testing decisions.
# With condition_data == 1, draws contradicted by the observed testing
# decisions are rejected and re-drawn.
set.seed(101)
  # exception = 1 for wave-2 pregnancies that took NIPT: the model never
  # predicts that outcome, so the conditioning step cannot require their
  # NIPT decision to match.
  actual$exception <- ifelse(actual$wave == 2 & actual$did_nipt == 1, 1, 0)
  
  if (condition_data == 1){
    # Conditioning on the data = rejection sampling of (a_i, c_i) draws.
    # A draw is a "keeper" when its predicted decisions agree with the data:
    #   did_nipt == 1: predicted NIPT must match, unless exception == 1
    #                  (wave-2 NIPT takers); invasive is not checked.
    #   did_nipt == 0: predicted invasive must match observed invasive.
    # NOTE(review): for did_nipt == 0 the NIPT-match result of the first
    # mutate is overwritten, so a draw predicting NIPT for a non-taker is
    # kept whenever its invasive prediction matches -- confirm intended.
    # share_keeper / num_keeper = share / count of keeper draws per pregnancy.

    # simulate - first iteration
    iter = 1
    draw <- draw_a_c_counter_xs(par=est_par, data=actual, f_a=f_a, f_c=f_c, x_vars = x_vars)
    draw <- as.data.frame(draw) %>%
      mutate(oop2_i = 0)
    simulated <- model_decisions_counterfactuals(par=est_par, data=draw, prior=quo(p_i), setup = 1) %>%
      as.data.frame() %>%
      dplyr::select(a_i, c_i, pred_nipt, pred_invasive)
  
    all <- cbind(actual, simulated) %>%
      mutate(keeper = ifelse(did_nipt==pred_nipt,
                             1,
                             ifelse(exception==1, 1, 0))) %>%
      mutate(keeper = ifelse(did_nipt==1,
                             keeper,
                             ifelse(did_invasive==pred_invasive, 1, 0))) %>%
      group_by(pregnancy) %>%
      mutate(share_keeper = mean(keeper)) %>%
      mutate(num_keeper = sum(keeper)) %>%
      ungroup()
  
    rm(actual, simulated, draw)
  
    # keep first iteration in a separate database
    first <- all
  
    # Next iterations: re-draw (a_i, c_i) only for the rejected draws of
    # pregnancies whose keeper share is below 10%; pregnancies at or above
    # 10% keep all their current draws. 1 initial + 249 re-draw rounds.
    while(iter < 250){
      iter = iter + 1
      
      good <- all %>% filter(share_keeper >= 0.1 | keeper==1) %>%
        dplyr::select(-keeper, -share_keeper)
  
      bad_act <- all %>% filter(share_keeper < 0.1 & keeper==0) %>%
        dplyr::select(-a_i, -c_i, -pred_nipt, -pred_invasive, -keeper, -share_keeper)
      bad_draw <- draw_a_c_counter_xs(par=est_par, data=bad_act, f_a=f_a, f_c=f_c, x_vars=x_vars)
      bad_draw <- as.data.frame(bad_draw) %>%
        mutate(oop2_i = 0)
      
      bad_sim <- model_decisions_counterfactuals(par=est_par, data=bad_draw, prior=quo(p_i), setup = 1) %>%
        as.data.frame() %>%
        dplyr::select(a_i, c_i, pred_nipt, pred_invasive)
  
      bad <- cbind(bad_act, bad_sim)
  
      all <- rbind(good, bad) %>%
        mutate(keeper = ifelse(did_nipt==pred_nipt,
                              1,
                              ifelse(exception==1, 1, 0))) %>%
        mutate(keeper = ifelse(did_nipt==1,
                              keeper,
                              ifelse(did_invasive==pred_invasive, 1, 0))) %>%
        arrange(pregnancy) %>%
        group_by(pregnancy) %>%
        mutate(share_keeper = mean(keeper)) %>%
        mutate(num_keeper = sum(keeper)) %>%
        ungroup()
      rm(good, bad_act, bad_draw,  bad_sim, bad)
  
    }
    rm(iter)
    
    # Diagnostics on the final draws (the filter to keeper draws is applied
    # later). summary() is wrapped in print(): inside a braced block only the
    # last value could be auto-printed, so a bare summary() never shows.
    print(summary(all$share_keeper))
    print("all$keeper")
    print(table(all$keeper))
  
    # Draws of pregnancies with no keeper at all; these pregnancies are lost
    # when conditioning on the data.
    bad <- all %>% 
      filter(num_keeper == 0) %>%
      dplyr::select(pregnancy, did_nipt, did_invasive, fetus_risk, pred_invasive, pred_nipt,  p_i, lan, wave, oop_i, age, sim_positive, a_i, c_i, sim_nipt_result)
    
    ### Write dataset with the bad draws
    if (use_full_sample == 0) {
      unique_bad <- bad %>%
        dplyr::select(pregnancy, did_nipt, did_invasive, p_i, fetus_risk, oop_i) 
      unique_bad <- unique(unique_bad)
      write.dta(unique_bad, file=paste0(TEMP, "/estimate_w_xs/unique_bad_draws.dta"))
      
      ### Stats on pregnancies lost through conditioning -> pregnancy_stats.csv
      ### Rows: 1 = lost (no keeper draw), 2 = retained, 3 = all.
      ### Columns: count, share of all pregnancies, count with p_i < 1/200,
      ### share of the row's pregnancies with p_i < 1/200.
      bad_id <- bad %>%
        select(pregnancy,p_i)
      unique_bad <- unique(bad_id)
      all_unique <- all %>%
        select(pregnancy, p_i, num_keeper) 
      all_unique <- unique(all_unique)
      preg_count <-nrow(all_unique)
      bad_preg_count <- nrow(filter(all_unique, num_keeper == 0))
      good_preg_count <- nrow(filter(all_unique, num_keeper > 0))
      share_bad_unique <- bad_preg_count / preg_count
      share_good_unique <- 1 - share_bad_unique
      num_bad_low_risk <- nrow(filter(all_unique, num_keeper == 0 & p_i < 1/200))
      num_all_low_risk <-nrow(filter(all_unique, p_i < 1/200))
      num_good_low_risk <- nrow(filter(all_unique, num_keeper > 0 & p_i < 1/200))
      share_bad_low_risk <- num_bad_low_risk / bad_preg_count
      share_all_low_risk <- num_all_low_risk / preg_count
      share_good_low_risk <- num_good_low_risk / good_preg_count
      unique_preg_stats <- matrix(NA, nrow = 3, ncol = 4)
      unique_preg_stats[1,] <- c(bad_preg_count, share_bad_unique, num_bad_low_risk, share_bad_low_risk)
      unique_preg_stats[2,] <- c(good_preg_count, share_good_unique, num_good_low_risk, share_good_low_risk)
      unique_preg_stats[3,] <- c(preg_count, 1, num_all_low_risk, share_all_low_risk)
      write.csv(unique_preg_stats, file=paste0(RESULTS, "/estimate_w_xs/pregnancy_stats.csv"), na=".", row.names=TRUE)
      #####
    
      # NOTE(review): pref_chrom is added to `bad`, but `bad` is not used
      # again in the visible part of this script.
      bad <- bad %>% mutate(pref_chrom = ifelse(c_i >= a_i, 1, 0))
    }
  } else {
    # No conditioning: one (a_i, c_i) draw per row, kept as is.
    draw <- draw_a_c_counter_xs(par=est_par, data=actual, f_a=f_a, f_c=f_c, x_vars = x_vars)
    draw <- as.data.frame(draw) %>%
      mutate(oop2_i = 0)
    simulated <- model_decisions_counterfactuals(par=est_par, data=draw, prior=quo(p_i), setup = 1) %>%
      as.data.frame() %>%
      dplyr::select(a_i, c_i, pred_nipt, pred_invasive)
    all <- cbind(actual, simulated)
  }
  
  if (condition_data == 1) {
    print("Condition on data")
    # Keep only draws consistent with the observed decisions.
    all <- all %>% 
      filter(keeper==1)
    # Simulation weight = 1 / number of keeper draws of the pregnancy (same
    # for every draw of a pregnancy), so each retained pregnancy's weights
    # sum to 1. num_keeper is that count directly. It equals
    # share_keeper * K when every pregnancy has K rows, but K only exists
    # when expand == 1. Pregnancies with no keeper have no rows left here,
    # so num_keeper > 0.
    all <- all %>% mutate(sim_wgt = 1 / num_keeper)
    print(table(all$keeper))
    
    # How many pregnancies had every draw accepted (share_keeper == 1)?
    weight_stats <- all %>%
      select(pregnancy, sim_wgt, share_keeper)
    unique_wt_stats <- unique(weight_stats)
    unique_wt_stats <- unique_wt_stats %>%
      mutate(all_keepers = ifelse(share_keeper == 1, 1, 0))
    print("unique weight stats")
    print(table(unique_wt_stats$all_keepers))
  } else {
    print("No condition on data") #  to check what happens if we don't condition on data
    # Unconditional simulation: every draw gets weight 1.
    all <- all %>% mutate(sim_wgt = 1) 
  }
 
# Step 2: for each observation, draw whether an invasive test would cause a
# miscarriage (unknown to the mother): sim_inv_mis = 1 with probability pa.
# This continues the RNG stream seeded in Step 1, so it must stay after it.
all$sim_inv_mis <- ifelse(runif(nrow(all), min = 0, max = 1) <= pa, 1, 0)
  

# Step 3: clean up
  # Keep only the variables used by the counterfactuals below, then add:
  #   oop2_i = 0  out-of-pocket cost of invasive testing (0 for all = actual)
  #   rec_i       1 if fetus_risk <= 200 (KUB risk of at least 1/200), else 0
  #   low / med / high  KUB risk-group indicators. They use the same
  #               fetus_risk cut-offs as the out_low / out_med / out_high
  #               splits of the counterfactuals, so exactly one of them is 1
  #               for any non-missing fetus_risk:
  #                 low:  fetus_risk > 200
  #                 med:  50 < fetus_risk <= 200
  #                 high: fetus_risk <= 50
  all <- all %>%
    dplyr::select(pregnancy, did_nipt, did_invasive, live_birth, p_i, wave, oop_i, age, ave_kub_age, sim_positive, a_i, c_i, sim_wgt, sim_inv_mis, sim_nipt_result, fetus_risk) %>%
    mutate(oop2_i = 0) %>%
    mutate(rec_i = ifelse(fetus_risk <= 200, 1, 0)) %>% 
    mutate(low = ifelse(fetus_risk > 200, 1, 0)) %>%
    mutate(med = ifelse(fetus_risk <= 200 & fetus_risk > 50, 1, 0)) %>%
    mutate(high = ifelse(fetus_risk <= 50, 1, 0))
  
  
  print("Weighted share of accepted draws with c_i greater than a_i" )
  all <- all %>% mutate(c_g_a = ifelse(c_i > a_i, 1, 0))
  print(weighted.mean(all$c_g_a, all$sim_wgt))
  
# The first-iteration copy of the draws only exists when conditioning.
if (condition_data == 1) {
  rm(first)
}

#### Calc counterfactual decisions and outcomes ####
  
  # Number of pregnancies in the simulation sample (before the K-fold
  # expansion): the full sample when use_full_sample == 1, otherwise the
  # estimation sample.
  n_preg <- if (use_full_sample == 1) nrow(full_sample) else nrow(analysis)
  # Weighted shares of draws in each KUB risk group (low / med / high)
  share_low <- weighted.mean(all$low, all$sim_wgt)
  share_med <- weighted.mean(all$med, all$sim_wgt)
  share_high <- weighted.mean(all$high, all$sim_wgt)
  sum_wgts <- sum(all$sim_wgt)
  
  
  # Set cost to government of tests (in dollars)
  g_cost_inv <- 1248.50 # 11,000 SEK 
  g_cost_nipt <- 567.5 # 5000 SEK
  g_cost_nt <- 174 #1,500 SEK

  
# Counterfactual 0 = model predictions
# Baseline: decisions and outcomes under the settings already stored in
# `all` (oop_i from the data, oop2_i = 0, prior = own KUB score p_i).
# NOTE(review): unlike the other counterfactuals, cf0[17] is not normalized
# by the no-testing surplus in the visible code -- confirm intended.
  cf0_dec <- model_decisions_counterfactuals(par = est_par, data = all, prior = quo(p_i))
  cf0_out <- outcomes(par = est_par, data = cf0_dec, prior = quo(p_i))
  cf0 <- outcomes_shares(data = cf0_out)
  rm(cf0_dec, cf0_out)
  
# Counterfactual 1 = mothers already know whether the fetus has a
# chromosomal abnormality. Nobody tests, so outcomes are built here directly
# instead of with outcomes():
#   - abort when sim_positive == 1 and a_i > c_i; otherwise a live birth
#   - no miscarriages (no invasive tests)
#   - utility: a_i after an abortion, c_i for a live birth with the
#     abnormality, 0 otherwise, less any out-of-pocket testing cost
  cf1_out <- all %>%
    mutate(
      pred_nipt = 0,
      pred_invasive = 0,
      sim_abort = ifelse(sim_positive == 1 & (a_i > c_i), 1, 0),
      sim_live = ifelse(sim_abort == 1, 0, 1),
      sim_mis = 0,
      sim_terminate = sim_abort + sim_mis,
      sim_nipt = pred_nipt,
      sim_invasive = pred_invasive,
      base_util = ifelse(sim_live == 0, a_i, ifelse(sim_positive == 1, c_i, 0)),
      sim_util = ifelse(pred_nipt == 0,
                        base_util - (oop2_i * pred_invasive),
                        base_util - oop_i - oop2_i * (pred_invasive))
    ) %>%
    dplyr::select(sim_wgt, sim_positive, sim_nipt, sim_invasive, sim_live, sim_abort, sim_mis, sim_terminate, sim_util, a_i, c_i, oop_i, oop2_i, p_i, fetus_risk)

  # Overall shares, then shares by KUB risk group
  cf1 <- outcomes_shares(data = cf1_out)
  cf1_low <- outcomes_shares_split(data = filter(cf1_out, fetus_risk > 200))
  cf1_med <- outcomes_shares_split(data = filter(cf1_out, fetus_risk <= 200 & fetus_risk > 50))
  cf1_high <- outcomes_shares_split(data = filter(cf1_out, fetus_risk <= 50))
  rm(cf1_out)

  # No tests are taken, so the government testing-cost entries are zero.
  cf1[14:16] <- 0
  
  
# Counterfactual 2 = no testing for anyone
# settings:
#   oop_i = Inf and oop2_i = Inf (for all), i.e. prohibitive test prices
#   prior = p_i (i.e., prior = truth = KUB score)
#   NOTE(review): the old settings list said "rec_i = 0 (for all)", but
#   rec_i is not reset here; it keeps the Step 3 value -- confirm.
  cf2_set <- mutate(all, oop_i = Inf, oop2_i = Inf)
  cf2_dec <- model_decisions_counterfactuals(par = est_par, data = cf2_set, prior = quo(p_i))
  cf2_out <- outcomes(par = est_par, data = cf2_dec, prior = quo(p_i))

  # Overall shares, then shares by KUB risk group
  cf2 <- outcomes_shares(data = cf2_out)
  cf2_low <- outcomes_shares_split(data = filter(cf2_out, fetus_risk > 200))
  cf2_med <- outcomes_shares_split(data = filter(cf2_out, fetus_risk <= 200 & fetus_risk > 50))
  cf2_high <- outcomes_shares_split(data = filter(cf2_out, fetus_risk <= 50))
  rm(cf2_set, cf2_dec, cf2_out)

  # Consumer surplus (element 17) of each counterfactual is expressed
  # relative to this no-testing surplus; normalize CF1 and CF2 here.
  surplus_norm <- cf2[17]
  cf1[17] <- cf1[17] - surplus_norm
  cf2[17] <- cf2[17] - surplus_norm

# Counterfactual 3 = no NIPT; invasive testing offered (free) only to
# KUB >= 1/200, i.e. fetus_risk <= 200
# settings:
#   oop_i = Inf (for all)
#   oop2_i = 0 if fetus_risk <= 200, Inf otherwise
#   prior = p_i (i.e., prior = truth = KUB score)
  cf3_set <- mutate(all, oop_i = Inf, oop2_i = ifelse(fetus_risk <= 200, 0, Inf))
  cf3_dec <- model_decisions_counterfactuals(par = est_par, data = cf3_set, prior = quo(p_i))
  cf3_out <- outcomes(par = est_par, data = cf3_dec, prior = quo(p_i))
  cf3 <- outcomes_shares(data = cf3_out)
  rm(cf3_set, cf3_dec, cf3_out)

  # Normalize consumer surplus relative to the no-testing CS (Counterfactual 2)
  cf3[17] <- cf3[17] - surplus_norm
  
# Counterfactual 4 = no NIPT; invasive testing offered free to everyone
# settings:
#   oop_i = Inf (for all)
#   oop2_i = 0 (for all)
#   prior = p_i (i.e., prior = truth = KUB score)
  cf4_set <- mutate(all, oop_i = Inf, oop2_i = 0)
  cf4_dec <- model_decisions_counterfactuals(par = est_par, data = cf4_set, prior = quo(p_i))
  cf4_out <- outcomes(par = est_par, data = cf4_dec, prior = quo(p_i))

  # Overall shares, then shares by KUB risk group
  cf4 <- outcomes_shares(data = cf4_out)
  cf4_low <- outcomes_shares_split(data = filter(cf4_out, fetus_risk > 200))
  cf4_med <- outcomes_shares_split(data = filter(cf4_out, fetus_risk <= 200 & fetus_risk > 50))
  cf4_high <- outcomes_shares_split(data = filter(cf4_out, fetus_risk <= 50))

  # Save the simulated outcomes for the plots of testing rates by counterfactual
  cf4_file <- if (use_full_sample == 1) "/estimate_w_xs/inv_only_cf_out_full.dta" else "/estimate_w_xs/inv_only_cf_out.dta"
  write.dta(cf4_out, file = paste0(TEMP, cf4_file))
  rm(cf4_set, cf4_dec, cf4_out, cf4_file)

  # Normalize consumer surplus relative to the no-testing CS (Counterfactual 2)
  cf4[17] <- cf4[17] - surplus_norm
  
# Counterfactual 5 = everyone gets free testing
# settings:
#   oop_i = 0 (for all)
#   oop2_i = 0 (for all)
#   prior = p_i (i.e., prior = truth = KUB score)
  cf5_set <- mutate(all, oop_i = 0, oop2_i = 0)
  cf5_dec <- model_decisions_counterfactuals(par = est_par, data = cf5_set, prior = quo(p_i))
  cf5_out <- outcomes(par = est_par, data = cf5_dec, prior = quo(p_i))

  # Overall shares, then shares by KUB risk group
  cf5 <- outcomes_shares(data = cf5_out)
  cf5_low <- outcomes_shares_split(data = filter(cf5_out, fetus_risk > 200))
  cf5_med <- outcomes_shares_split(data = filter(cf5_out, fetus_risk <= 200 & fetus_risk > 50))
  cf5_high <- outcomes_shares_split(data = filter(cf5_out, fetus_risk <= 50))

  # Save the simulated outcomes for the plots of testing rates by counterfactual
  cf5_file <- if (use_full_sample == 1) "/estimate_w_xs/free_nipt_cf_out_full.dta" else "/estimate_w_xs/free_nipt_cf_out.dta"
  write.dta(cf5_out, file = paste0(TEMP, cf5_file))
  rm(cf5_set, cf5_dec, cf5_out, cf5_file)

  # Normalize consumer surplus relative to the no-testing CS (Counterfactual 2)
  cf5[17] <- cf5[17] - surplus_norm
  
  
# Counterfactual 6 = everyone pays for their own testing
# settings:
#   oop_i = 5000 SEK ~= 567.50 USD (for all)
#   oop2_i = 11000 SEK ~= 1248.50 USD (for all)
#   prior = p_i (i.e., prior = truth = KUB score)
  cf6_set <- mutate(all, oop_i = 567.50, oop2_i = 1248.50)
  cf6_dec <- model_decisions_counterfactuals(par = est_par, data = cf6_set, prior = quo(p_i))
  cf6_out <- outcomes(par = est_par, data = cf6_dec, prior = quo(p_i))
  cf6 <- outcomes_shares(data = cf6_out)
  rm(cf6_set, cf6_dec, cf6_out)

  # Normalize consumer surplus relative to the no-testing CS (Counterfactual 2)
  cf6[17] <- cf6[17] - surplus_norm
  
# Counterfactual 7 = everyone pays for their own NIPT; invasive is free
# settings:
#   oop_i = 5000 SEK ~= 567.50 USD (for all)
#   oop2_i = 0 (for all)
#   prior = p_i (i.e., prior = truth = KUB score)
  cf7_set <- mutate(all, oop_i = 567.50, oop2_i = 0)
  cf7_dec <- model_decisions_counterfactuals(par = est_par, data = cf7_set, prior = quo(p_i))
  cf7_out <- outcomes(par = est_par, data = cf7_dec, prior = quo(p_i))

  # Overall shares, then shares by KUB risk group
  cf7 <- outcomes_shares(data = cf7_out)
  cf7_low <- outcomes_shares_split(data = filter(cf7_out, fetus_risk > 200))
  cf7_med <- outcomes_shares_split(data = filter(cf7_out, fetus_risk <= 200 & fetus_risk > 50))
  cf7_high <- outcomes_shares_split(data = filter(cf7_out, fetus_risk <= 50))

  # Save the simulated outcomes for the plots of testing rates by counterfactual
  cf7_file <- if (use_full_sample == 1) "/estimate_w_xs/oop_cf_out_full.dta" else "/estimate_w_xs/oop_cf_out.dta"
  write.dta(cf7_out, file = paste0(TEMP, cf7_file))
  rm(cf7_set, cf7_dec, cf7_out, cf7_file)

  # Normalize consumer surplus relative to the no-testing CS (Counterfactual 2)
  cf7[17] <- cf7[17] - surplus_norm
  
# Counterfactual 8 = free invasive for all; free NIPT for medium-risk KUB
# (50 < fetus_risk <= 200); everyone else pays for their own NIPT
# settings:
#   oop_i = 0 if 50 < fetus_risk <= 200, else 5000 SEK ~= 567.50 USD
#   oop2_i = 0 (for all)
#   prior = p_i (i.e., prior = truth = KUB score)
  cf8_set <- mutate(all,
                    oop_i = ifelse((fetus_risk <= 200) & (fetus_risk > 50), 0, 567.50),
                    oop2_i = 0)
  cf8_dec <- model_decisions_counterfactuals(par = est_par, data = cf8_set, prior = quo(p_i))
  cf8_out <- outcomes(par = est_par, data = cf8_dec, prior = quo(p_i))
  cf8 <- outcomes_shares(data = cf8_out)
  rm(cf8_set, cf8_dec, cf8_out)

  # Normalize consumer surplus relative to the no-testing CS (Counterfactual 2)
  cf8[17] <- cf8[17] - surplus_norm
  
  
# Counterfactual 9 = free invasive for all; free NIPT for KUB >= 1/200
# (fetus_risk <= 200); everyone else pays for their own NIPT
# settings:
#   oop_i = 0 if fetus_risk <= 200, else 5000 SEK ~= 567.50 USD
#   oop2_i = 0 (for all)
#   prior = p_i (i.e., prior = truth = KUB score)
  cf9_set <- mutate(all, oop_i = ifelse(fetus_risk <= 200, 0, 567.50), oop2_i = 0)
  cf9_dec <- model_decisions_counterfactuals(par = est_par, data = cf9_set, prior = quo(p_i))
  cf9_out <- outcomes(par = est_par, data = cf9_dec, prior = quo(p_i))
  cf9 <- outcomes_shares(data = cf9_out)

  # Save the simulated outcomes for the plots of testing rates by counterfactual
  cf9_file <- if (use_full_sample == 1) "/estimate_w_xs/g_200_cf_out_full.dta" else "/estimate_w_xs/g_200_cf_out.dta"
  write.dta(cf9_out, file = paste0(TEMP, cf9_file))
  rm(cf9_set, cf9_dec, cf9_out, cf9_file)

  # Normalize consumer surplus relative to the no-testing CS (Counterfactual 2)
  cf9[17] <- cf9[17] - surplus_norm
  
  
# Counterfactual 10 = NIPT directly, without KUB
# settings: 
#   oop_i = 0 (for all)
#   oop2_i = 0 (for all)
#   prior = ave_kub_age (i.e., prior = ave kub by age != truth = own kub)
#   additional: everyone must get NIPT 
#   NOTE(review): rec_i is not reset to 0 here (it keeps its Step 3 value),
#   although the old settings list said "rec_i = 0 (for all)" -- confirm.
  
  # Implement settings
  set <- all %>%
    mutate(oop_i = 0) %>% 
    mutate(oop2_i = 0)
  
  # Predict decisions
  dec <- model_decisions_counterfactuals(par=est_par, data=set, prior=quo(ave_kub_age))
    # Everyone takes NIPT, so the invasive decision follows the simulated
    # NIPT result: invasive_P after a positive result, invasive_N after a
    # negative one (both columns come from model_decisions_counterfactuals).
    # Any true_pr_pos column is dropped so it is not passed to outcomes().
    dec <- dec %>%
      mutate(pred_nipt = 1) %>%
      mutate(pred_invasive = ifelse(sim_nipt_result == 1, invasive_P, invasive_N)) %>%
      dplyr::select(-any_of("true_pr_pos"))
  
  # Calculate outcomes 
  out <- outcomes(par=est_par, data=dec, prior=quo(ave_kub_age))
  
  # Calculate shares of outcomes 
  cf10 <- outcomes_shares(data=out)
  cf10[14] = cf10[14] - g_cost_nt # Subtract cost of NT (bc skip it) from total testing cost
  rm(set, dec, out)
  
  # normalize consumer surplus to be relative to no testing CS
  cf10[17] = cf10[17] - surplus_norm
  
  
# Counterfactual 11 = NIPT directly, without KUB, if woman's age >= 32
# settings: 
# for women aged >= 32: 
#   oop_i = 0, oop2_i = 0
#   prior = ave_kub_age (i.e., prior = ave kub by age != truth = own kub)
#   additional: must get NIPT 
# for women aged < 32: 
#   oop_i = 5000 SEK ~= 567.50 USD, oop2_i = 0
#   prior = p_i (i.e., prior = truth = own kub)
#   NIPT is NOT forced: they choose given the out-of-pocket price
# NOTE(review): rec_i is not reset to 0 here (it keeps its Step 3 value),
# although the old settings list said "rec_i = 0" -- confirm. Rows with a
# missing age match neither filter and are dropped.
  
  # Implement settings
  set_old <- all %>%
    mutate(oop_i = 0) %>% 
    mutate(oop2_i = 0) %>% 
    filter(age >= 32)
   

  set_young <- all %>%
    mutate(oop_i = 567.50) %>% 
    mutate(oop2_i = 0) %>% 
    filter(age < 32)
  
  # Predict decisions
  dec_old <- model_decisions_counterfactuals(par=est_par, data=set_old, prior=quo(ave_kub_age))
    # Older women all take NIPT, so their invasive decision follows the
    # simulated NIPT result: invasive_P if positive, invasive_N if negative
    # (both columns come from model_decisions_counterfactuals). Any
    # true_pr_pos column is dropped so it is not passed to outcomes().
    dec_old <- dec_old %>%
      mutate(pred_nipt = 1) %>%
      mutate(pred_invasive = ifelse(sim_nipt_result == 1, invasive_P, invasive_N)) %>%
      dplyr::select(-any_of("true_pr_pos"))
  
  dec_young <- model_decisions_counterfactuals(par=est_par, data=set_young, prior=quo(p_i))
  
  # Calculate outcomes 
  out_old <- outcomes(par=est_par, data=dec_old, prior=quo(ave_kub_age))
  
  out_young <- outcomes(par=est_par, data=dec_young, prior=quo(p_i))
  
  out <- bind_rows(out_old, out_young)
  
  # Calculate shares of outcomes 
  cf11 <- outcomes_shares(data=out)
  cf11[14] = cf11[14] - g_cost_nt # Subtract cost of NT (bc skip it) from total testing cost
  rm(set_old, set_young, dec_old, dec_young, out_old, out_young, out)
  
  # normalize consumer surplus to be relative to no testing CS
  cf11[17] = cf11[17] - surplus_norm
  
if (use_full_sample == 1){
  # Counterfactual 12 = everyone gets free testing, skipping NT but with NT score as prior
  # settings: 
  #   oop_i = 0 (for all)
  #   oop2_i = 0 (for all)
  #   rec_i = 0 (for all)
  #   prior = p_i (i.e., prior = truth = KUB score)
  #   subtract NT cost from govt testing cost
  
  # Implement settings
  set <- all %>%
    mutate(oop_i = 0) %>% 
    mutate(oop2_i = 0)
  
  
  # Predict decisions
  dec <- model_decisions_counterfactuals(par=est_par, data=set, prior=quo(p_i))
  
  # Calculate outcomes 
  out <- outcomes(par=est_par, data=dec, prior=quo(p_i))
  
  # Calculate shares of outcomes 
  cf12 <- outcomes_shares(data=out)
  
  rm(set, dec, out)  
  
  # normalize consumer surplus to be relative to no testing CS
  cf12[14] = cf12[14] - g_cost_nt # Subtract cost of NT (bc skip it) from total testing cost
  cf12[17] = cf12[17] - surplus_norm
  
  
  # Counterfactual 13 = CF12, but NIPT is free only for women aged 32+;
  # younger women pay out of pocket.
  # Settings:
  #   oop_i  = 0 for age >= 32; 567.50 USD (~5000 SEK) for age < 32
  #   oop2_i = 0 for all
  #   rec_i  = 0 for all (not assigned here -- NOTE(review): confirm)
  #   prior  = p_i (prior = truth = KUB score)
  #   subtract NT cost from govt testing cost
  # NOTE(review): rows with NA age satisfy neither filter and drop out of
  # both groups -- confirm `all` has no missing ages.
  set_old <- all %>%
    mutate(oop_i = 0, oop2_i = 0) %>%
    filter(age >= 32)

  set_young <- all %>%
    mutate(oop_i = 567.50, oop2_i = 0) %>%
    filter(age < 32)

  # Model each age group separately (CF14 below is the variant that forces
  # NIPT for the older group). Call order is kept as is in case the model
  # functions consume random draws.
  dec_old <- model_decisions_counterfactuals(par = est_par, data = set_old, prior = quo(p_i))
  dec_young <- model_decisions_counterfactuals(par = est_par, data = set_young, prior = quo(p_i))

  out_old <- outcomes(par = est_par, data = dec_old, prior = quo(p_i))
  out_young <- outcomes(par = est_par, data = dec_young, prior = quo(p_i))

  # Pool the two groups before computing shares
  out <- bind_rows(out_old, out_young)
  cf13 <- outcomes_shares(data = out)
  rm(set_old, set_young, dec_old, dec_young, out_old, out_young, out)

  # Element 14 = "G cost (p.c.)" (drop skipped NT cost);
  # element 17 = "Consumer surplus" (relative to no-testing CS)
  cf13[14] <- cf13[14] - g_cost_nt
  cf13[17] <- cf13[17] - surplus_norm
  
  
  # Counterfactual 14 = CF13, but NIPT is mandatory for women aged 32+.
  # Settings (same prices as CF13):
  #   oop_i  = 0 for age >= 32; 567.50 USD (~5000 SEK) for age < 32
  #   oop2_i = 0 for all
  #   rec_i  = 0 for all (not assigned here -- NOTE(review): confirm)
  #   prior  = p_i (prior = truth = KUB score)
  #   subtract NT cost from govt testing cost
  set_old <- all %>%
    mutate(oop_i = 0, oop2_i = 0) %>%
    filter(age >= 32)

  set_young <- all %>%
    mutate(oop_i = 567.50, oop2_i = 0) %>%
    filter(age < 32)

  # Older group: take the model's decisions, then override so everyone gets
  # NIPT and invasive testing follows the simulated NIPT result.
  dec_old <- model_decisions_counterfactuals(par = est_par, data = set_old, prior = quo(p_i))
  dec_old <- dec_old %>%
    mutate(
      pred_nipt = 1,
      # NOTE(review): true_pr_pos is dropped right after this mutate and is
      # not used by pred_invasive -- confirm whether that is intended.
      true_pr_pos = round((1 - p_i) * pfp + p_i * (1 - pfn), 10),
      pred_invasive = ifelse(sim_nipt_result == 1, invasive_P, invasive_N)
    ) %>%
    dplyr::select(-true_pr_pos)

  dec_young <- model_decisions_counterfactuals(par = est_par, data = set_young, prior = quo(p_i))

  out_old <- outcomes(par = est_par, data = dec_old, prior = quo(p_i))
  out_young <- outcomes(par = est_par, data = dec_young, prior = quo(p_i))

  # Pool the two groups before computing shares
  out <- bind_rows(out_old, out_young)
  cf14 <- outcomes_shares(data = out)
  rm(set_old, set_young, dec_old, dec_young, out_old, out_young, out)

  # Element 14 = "G cost (p.c.)" (drop skipped NT cost);
  # element 17 = "Consumer surplus" (relative to no-testing CS)
  cf14[14] <- cf14[14] - g_cost_nt
  cf14[17] <- cf14[17] - surplus_norm
  
}  


# Full counterfactuals table: one row per counterfactual (cf0-cf11), one
# column per outcome share / cost measure.
out_cf <- rbind(
  cf0, cf1, cf2, cf3, cf4, cf5, cf6, cf7, cf8, cf9, cf10, cf11
)
rownames(out_cf) <- c(
  "Model", "First-best", "No testing", "c3", "Invasive only", "Free NIPT",
  "cf6", "Out-of-pocket NIPT", "[1:51-1:200]", " > 1:201", "cf10", "cf11"
)
colnames(out_cf) <- c(
  "Any testing", "NIPT only", "Invasive only", "NIPT followed by invasive",
  "Live birth", "live, chrom ab", "live, chrom ab, a_i > c_i",
  "live, chrom ab, a_i < c_i", "live, no chrom ab",
  "Terminated", "term, chrom ab, a_i > c_i", "term, no chrom ab",
  "Bad outcome",
  "G cost (p.c.)", "G cost, NIPT", "G cost, Inv", "Consumer surplus",
  "Share NIPT oop"
)

# Log the table with measures as rows, plus its stargazer rendering
print(t(out_cf))
stargazer(t(out_cf), digits = 2)

#### Formatted counterfactual tables (value of NIPT, supplementary, insurance) ####

# Reshape stacked outcomes_shares() rows into the 15-column layout shared by
# the value-of-NIPT, supplementary and insurance-coverage tables.
#   tab        -- rbind() of counterfactual rows whose 18 columns follow the
#                 order labelled in out_cf above
#   row_labels -- one label per row of `tab`
# Returns a matrix: rows = counterfactuals, 15 labelled columns.
format_cf_table <- function(tab, row_labels) {
  tab[, c(6, 7, 8, 9)] <- tab[, c(9, 6, 7, 8)] # put "live, no CA" ahead of live with CAs
  tab[, c(11, 12)] <- tab[, c(12, 11)]         # put "term, no CA" ahead of term with CA
  tab <- tab[, -7]                             # drop "live w/CA"
  tab <- tab[, -12]                            # drop "bad outcomes"
  # Keep the first 15 columns (drops trailing "Share NIPT oop"). Coerce to
  # a matrix for all tables; previously only supp_cf was coerced.
  tab <- as.matrix(tab[, 1:15])
  rownames(tab) <- row_labels
  colnames(tab) <- c(
    "Any testing", "NIPT only", "Invasive only", "NIPT followed by invasive",
    "Live birth", "live, no CA", "live, CA, a_i > c_i", "live, CA, a_i < c_i",
    "Terminated", "term, no CA", "term, CA, a_i > c_i",
    "G cost (p.c.)", "G cost, NIPT", "G cost, Inv", "Consumer surplus"
  )
  tab
}

# Log a formatted table transposed (measures as rows), pass it to stargazer,
# and write it to RESULTS/estimate_w_xs/<file_name>.
# Returns the transposed table as a data.frame (what the CSV contains).
export_cf_table <- function(tab, file_name) {
  print(t(tab))
  stargazer(t(tab), digits = 2)
  tab_t <- data.frame(t(tab))
  write.csv(tab_t, file = paste0(RESULTS, "/estimate_w_xs/", file_name),
            na = ".", row.names = TRUE)
  tab_t
}

# Value of NIPT counterfactuals table
val_nipt_cf <- format_cf_table(
  rbind(cf2, cf1, cf4, cf5),
  c("No testing", "First-best", "Invasive only", "Free NIPT")
)
val_nipt_cf <- export_cf_table(val_nipt_cf, "val_nipt_table.csv")

# Supplementary counterfactuals (cf12-cf14 exist only for the full sample)
if (use_full_sample == 1) {
  supp_cf <- format_cf_table(
    rbind(cf12, cf13, cf14),
    c("cf12", "cf13", "cf14")
  )
  supp_cf <- export_cf_table(supp_cf, "supp_cf.csv")
}
print("Finished supp_cf")

# Insurance coverage of NIPT counterfactuals table
ins_nipt_cf <- format_cf_table(
  rbind(cf4, cf5, cf7, cf9, cf8),
  c("Invasive only", "Free NIPT", "Out-of-pocket NIPT", "> 1:201", "[1:51-1:200]")
)
ins_nipt_cf <- export_cf_table(ins_nipt_cf, "insurance_nipt_table.csv")

# Breakdown by risk group: each counterfactual row is followed by its
# low / med / high subsample rows. Saved transposed (measures as rows).
split_by_risk <- rbind(
  cf2, cf2_low, cf2_med, cf2_high,
  cf1, cf1_low, cf1_med, cf1_high,
  cf4, cf4_low, cf4_med, cf4_high,
  cf7, cf7_low, cf7_med, cf7_high,
  cf5, cf5_low, cf5_med, cf5_high
)
rownames(split_by_risk) <- c(
  "No testing", "No testing-Low", "No testing-Med", "No testing-High",
  "First-best", "First-best-Low", "First-best-Med", "First-best-High",
  "Invasive only", "Invasive only-Low", "Invasive only-Med", "Invasive only-High",
  "Out-of-pocket NIPT", "Out-of-pocket NIPT-Low", "Out-of-pocket NIPT-Med", "Out-of-pocket NIPT-High",
  "Free NIPT", "Free NIPT-Low", "Free NIPT-Med", "Free NIPT-High"
)
colnames(split_by_risk) <- c(
  "Any testing", "NIPT only", "Invasive only", "NIPT followed by invasive",
  "Live birth", "live, chrom ab", "live, chrom ab, a_i > c_i",
  "live, chrom ab, a_i < c_i", "live, no chrom ab",
  "Terminated", "term, chrom ab, a_i > c_i", "term, no chrom ab",
  "Bad outcome",
  "G cost (p.c.)", "G cost, NIPT", "G cost, Inv", "Consumer surplus",
  "Share NIPT oop"
)
split_by_risk <- data.frame(t(split_by_risk))
write.csv(split_by_risk,
          file = paste0(RESULTS, "/estimate_w_xs/split_by_risk.csv"),
          na = ".", row.names = TRUE)


# Save the full counterfactuals table; full-sample runs get a "_full" suffix.
write.csv(
  out_cf,
  file = paste0(RESULTS, "/estimate_w_xs/cf_outcomes",
                if (use_full_sample == 1) "_full" else "", ".csv"),
  na = ".", row.names = TRUE
)


### Counterfactuals of varying lower bound for NIPT coverage ####
# For each lower bound lb on a grid over [0, 1000], NIPT is free (oop_i = 0)
# for pregnancies with fetus_risk <= lb and costs 567.50 otherwise; oop2_i = 0
# for everyone. The grid endpoints are pinned:
#   - grid value 0 is evaluated and recorded as lb = 2, with nobody covered;
#   - lb = 1000 covers everyone.
#
# Columns of vary_lb_cf (outcomes_shares() element in brackets):
#   1 lb                        6 invasive only [3]
#   2 bad outcome share [13]    7 both tests [4]
#   3 govt cost [14]            8 terminated, no chrom abn [12]
#   4 any testing [1]           9 live birth w CA & a >= c [7]
#   5 NIPT only [2]            10 consumer surplus [17] - surplus_norm
N_grid <- 50
grid <- seq(from = 0, to = 1000, by = (1000 - 0) / N_grid)
vary_lb_cf <- matrix(NA, nrow = N_grid + 1, ncol = 10)

for (i in seq_along(grid)) {
  print(i)
  lb <- grid[i]
  if (lb == 0) {
    lb <- 2
  }

  # Endpoint overrides modify `set` (not `all`), so oop2_i = 0 holds at
  # every grid point, including the two endpoints.
  set <- all %>%
    mutate(oop_i = ifelse(fetus_risk <= lb, 0, 567.50)) %>%
    mutate(oop2_i = 0)
  if (lb == 2) {
    set <- set %>% mutate(oop_i = 567.50)  # bottom endpoint: nobody covered
  }
  if (lb == 1000) {
    set <- set %>% mutate(oop_i = 0)       # top endpoint: everyone covered
  }

  # Predict decisions and outcomes with the KUB score p_i as the prior
  dec <- model_decisions_counterfactuals(par = est_par, data = set, prior = quo(p_i))
  out <- outcomes(par = est_par, data = dec, prior = quo(p_i))
  out_shares <- outcomes_shares(data = out)

  # unlist() is a no-op for a numeric vector and flattens a one-row frame
  shares <- unlist(out_shares)
  vary_lb_cf[i, ] <- c(
    lb,
    shares[c(13, 14, 1, 2, 3, 4, 12, 7)],
    shares[17] - surplus_norm
  )

  rm(set, dec, out, out_shares, shares)
}


### Counterfactuals of varying upper bound for NIPT coverage ####
# For each bound ub on the grid 2, 4, ..., 200, NIPT is free (oop_i = 0) for
# pregnancies with ub <= fetus_risk <= 200 and costs 567.50 otherwise;
# oop2_i = 0 for everyone. At the top endpoint (ub = 200) nobody is covered.
#
# Columns of vary_ub_cf (outcomes_shares() element in brackets):
#   1 ub                        6 invasive only [3]
#   2 bad outcome share [13]    7 both tests [4]
#   3 govt cost [14]            8 terminated, no chrom abn [12]
#   4 any testing [1]           9 live birth w CA & a >= c [7]
#   5 NIPT only [2]            10 consumer surplus [17] - surplus_norm
N_grid <- 99
grid <- seq(from = 2, to = 200, by = (200 - 2) / N_grid)
vary_ub_cf <- matrix(NA, nrow = N_grid + 1, ncol = 10)

for (i in seq_along(grid)) {
  print(i)
  ub <- grid[i]

  set <- all %>%
    mutate(oop_i = ifelse((fetus_risk <= 200) & (fetus_risk >= ub), 0, 567.50)) %>%
    mutate(oop2_i = 0)
  # Top endpoint: nobody covered. Test ub against 200 directly; on this
  # even-valued grid `ub + 1 == 200` can never hold. Modify `set` (not
  # `all`) so that oop2_i = 0 is kept.
  if (ub >= 200) {
    set <- set %>% mutate(oop_i = 567.50)
  }

  # Predict decisions and outcomes with the KUB score p_i as the prior
  dec <- model_decisions_counterfactuals(par = est_par, data = set, prior = quo(p_i))
  out <- outcomes(par = est_par, data = dec, prior = quo(p_i))
  out_shares <- outcomes_shares(data = out)

  # unlist() is a no-op for a numeric vector and flattens a one-row frame
  shares <- unlist(out_shares)
  vary_ub_cf[i, ] <- c(
    ub,
    shares[c(13, 14, 1, 2, 3, 4, 12, 7)],
    shares[17] - surplus_norm
  )

  rm(set, dec, out, out_shares, shares)
}
  
# Log both threshold grids, then save them. Full-sample runs get a "_full"
# suffix on the file name.
print(vary_lb_cf)
print(vary_ub_cf)

write.csv(
  vary_lb_cf,
  file = paste0(RESULTS, "/estimate_w_xs/vary_lb_cf",
                if (use_full_sample == 1) "_full" else "", ".csv"),
  na = ".", row.names = TRUE
)
write.csv(
  vary_ub_cf,
  file = paste0(RESULTS, "/estimate_w_xs/vary_ub_cf",
                if (use_full_sample == 1) "_full" else "", ".csv"),
  na = ".", row.names = TRUE
)



#### CA/preg share by bin ####
# Attach bin_number from `analysis` to `all`, keyed on pregnancy.
# NOTE(review): assumes pregnancy is unique in `analysis`; duplicates would
# multiply rows of `all`.
analysis_merge <- dplyr::select(analysis, pregnancy, bin_number)
all <- left_join(all, analysis_merge, by = "pregnancy")

# Keep rows flagged sim_positive == 1 (presumably simulated CA cases --
# confirm) and save them as a Stata file, with "_full" for full-sample runs.
all_ca <- dplyr::select(
  filter(all, sim_positive == 1),
  pregnancy, sim_wgt, p_i, fetus_risk, bin_number
)
write.dta(
  all_ca,
  file = paste0(TEMP, "/estimate_w_xs/all_ca",
                if (use_full_sample == 1) "_full" else "", ".dta")
)


# Report wall-clock run time since start_time, then close the sink() log.
print("End time")
end_time <- Sys.time()
print(end_time)

time_diff <- difftime(end_time, start_time)
print(time_diff)

#### EXIT ####

sink()
#rm(list = ls())



  